from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
from scipy import misc
import tensorflow as tf
import numpy as np
import os
import facenet
import align.detect_face

# Build the MTCNN face detector once, at import time, inside its own
# TensorFlow graph and session. The resulting pnet / rnet / onet objects are
# module-level globals that load_and_align_data() passes to
# align.detect_face.detect_face().
with tf.Graph().as_default():
    # GPU settings for the detector session: memory fraction 0.4 with
    # allow_growth enabled.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4, allow_growth=True)
    # NOTE(review): this session is not closed in this block; presumably
    # pnet / rnet / onet keep running through it for the life of the process
    # -- confirm against align.detect_face.create_mtcnn.
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False, ))
    with sess.as_default():
        # The second argument (None) is presumably the model path, making
        # create_mtcnn fall back to its default weights -- TODO confirm.
        pnet, rnet, onet = align.detect_face.create_mtcnn(sess, None)


# # def load_facenet():
#     # with tf.Graph().as_default():
# # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False))
#     # return sess
#             # sess = tf.Session()
#             # Load the model
#             # 这里要改为自己的模型位置
# # model = r'D:\CODE\GITCODE\face_code\facenet_facerecognition\20170512-110547'
# # facenet.load_model(model)
# # # Get input and output tensors
# # images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
# # embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
# # phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
#
#
# Modified version of load_and_align_data.
# Input: an RGB image passed in as an np.ndarray.
def load_and_align_data(img, image_size, margin):
    """Detect faces in ``img`` and return them cropped, resized and prewhitened.

    Args:
        img: Image array indexed as ``img[row, col, channel]``. The original
            author's note says an RGB ``np.ndarray`` is expected.
        image_size: Side length, in pixels, of each square output crop.
        margin: Total padding added around each detected box; half of it is
            applied on every side before clamping to the image bounds.

    Returns:
        A 3-tuple ``(found, det, crop_image)``.

        * ``(0, 0, 0)`` when the detector reports no faces.
        * Otherwise ``found`` is ``1``; ``det`` is the detector's box array
          cast to ``int`` (columns 0-3 are the padded, clamped
          ``x1, y1, x2, y2``; any further columns are truncated to integers
          as well); ``crop_image`` is the stack of prewhitened crops, one per
          row of ``det``.
    """
    # Detector settings: minimum face size, thresholds for the three
    # detection steps, and the scale factor.
    min_face_size = 20
    step_thresholds = [0.6, 0.7, 0.7]
    scale_factor = 0.709

    img_size = np.asarray(img.shape)[0:2]

    boxes, _ = align.detect_face.detect_face(
        img, min_face_size, pnet, rnet, onet, step_thresholds, scale_factor)

    # Nothing detected: signal failure with the all-zero tuple.
    if len(boxes) < 1:
        return 0, 0, 0

    # Grow every box by half the margin on each side, keeping it inside the
    # image (columns 0/2 are clamped against the width, 1/3 against the height).
    half_margin = margin / 2
    max_x = img_size[1] - 1
    max_y = img_size[0] - 1
    boxes[:, 0] = np.maximum(boxes[:, 0] - half_margin, 0)
    boxes[:, 1] = np.maximum(boxes[:, 1] - half_margin, 0)
    boxes[:, 2] = np.minimum(boxes[:, 2] + half_margin, max_x)
    boxes[:, 3] = np.minimum(boxes[:, 3] + half_margin, max_y)
    det = boxes.astype(int)

    # Crop each face, resize it to the requested square and prewhiten it.
    target_shape = (image_size, image_size)
    faces = [
        facenet.prewhiten(
            misc.imresize(img[top:bottom, left:right, :], target_shape, interp='bilinear'))
        for left, top, right, bottom in det[:, :4]
    ]

    # np.stack turns the list of per-face arrays into a single batch array.
    return 1, det, np.stack(faces)


#
# # 加载人脸库
# def load_faces():
#
#     image = []
#     nrof_images = 0
#     # 这里要改为自己emb_img文件夹的位置
#     emb_dir = r'D:\CODE\GITCODE\face_code\facenet_facerecognition\train_dir\emb_img'
#     all_obj = []
#     for i in os.listdir(emb_dir):
#         all_obj.append(i)
#         img = misc.imread(os.path.join(emb_dir, i), mode='RGB')  # 加载
#         prewhitened = facenet.prewhiten(img)
#         image.append(prewhitened)
#         nrof_images = nrof_images + 1
#     images = np.stack(image)
#     feed_dict = {images_placeholder: images, phase_train_placeholder: False}
#     compare_emb = sess.run(embeddings, feed_dict=feed_dict)
#     compare_num = len(compare_emb)
#     # 保存compare_emb到数据库，下次直接加载
#     return compare_num
#
#
# # 检测人脸
# def detect_image(rgb_frame):
#     mark, bounding_box, crop_image = load_and_align_data(rgb_frame, 160, 44)
#     return mark, bounding_box, crop_image
#
#
# # 获取特征
# def get_image_feature(crop_image):
#     feed_dict = {images_placeholder: crop_image, phase_train_placeholder: False}
#     emb = sess.run(embeddings, feed_dict=feed_dict)
#     return emb
#
#
# def compare_face(emb_img, all_img):
#     fin_obj = []
#     for i in range(len(emb_img)):
#         dist_list = []
#         for j in range(all_img):
#             dist = np.sqrt(np.sum(np.square(np.subtract(emb_img[i, :], all_img[j, :]))))
#             dist_list.append(dist)
#         min_value = min(dist_list)
#         if (min_value > 0.65):
#             fin_obj.append('unknow')
#         else:
#             fin_obj.append(all_img[dist_list.index(min_value)])
#     return fin_obj
#
#
# def main():
#     all_num = load_faces()  # 加载人脸库
#     # 开启ip摄像头
#     # video = "http://admin:admin@192.168.199.102:8081/"  # 此处@后的ipv4 地址需要修改为自己的地址
#     video = 0
#     # 参数为0表示打开内置摄像头，参数是视频文件路径则打开视频
#     capture = cv2.VideoCapture(video)
#     # capture = cv2.VideoCapture(0)
#     cv2.namedWindow("camera", 1)
#     timer = 0
#     while True:
#         ret, frame = capture.read()
#
#         # rgb frame np.ndarray 480*640*3
#         rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#
#         # 获取 判断标识 bounding_box crop_image  人脸检测
#         # mark, bounding_box, crop_image = load_and_align_data(rgb_frame, 160, 44)
#         mark, bounding_box, crop_image = detect_image(rgb_frame)
#         if (mark):
#             emb = get_image_feature(crop_image)  # 特征提取
#             temp_num = len(emb)
#             # 为bounding_box 匹配标签
#             fin_obj = compare_face(emb, all_num)
#             # 在frame上绘制边框和文字
#             for rec_position in range(temp_num):
#                 cv2.rectangle(frame, (bounding_box[rec_position, 0], bounding_box[rec_position, 1]),
#                               (bounding_box[rec_position, 2], bounding_box[rec_position, 3]), (0, 255, 0),
#                               2, 8, 0)
#
#                 cv2.putText(
#                     frame,
#                     fin_obj[rec_position],
#                     (bounding_box[rec_position, 0], bounding_box[rec_position, 1]),
#                     cv2.FONT_HERSHEY_COMPLEX_SMALL,
#                     0.8,
#                     (0, 0, 255),
#                     thickness=2,
#                     lineType=2)
#
#             cv2.imshow('camera', frame)
#
#         # cv2.imshow('camera',frame)
#         key = cv2.waitKey(10)
#         if key == 27:
#             # esc键退出
#             print("esc break...")
#             break
#     # When everything is done, release the capture
#     capture.release()
#     cv2.destroyWindow("camera")


# Default locations; override them through main_bak()'s keyword arguments.
_DEFAULT_MODEL_PATH = r'D:\CODE\GITCODE\face_code\facenet_facerecognition\20170512-110547'
_DEFAULT_EMB_DIR = r'D:\CODE\GITCODE\face_code\facenet_facerecognition\train_dir\emb_img'


def _match_labels(emb, compare_emb, all_obj, threshold):
    """Return, for each row of ``emb``, the label of its nearest gallery entry.

    A face whose smallest Euclidean distance to the gallery embeddings
    exceeds ``threshold`` is labelled ``'unknow'``; otherwise the distance is
    printed and the matching entry of ``all_obj`` is used.
    """
    labels = []
    for query in emb:
        # Distances from this face to every gallery embedding in one pass.
        dists = np.sqrt(np.sum(np.square(compare_emb - query), axis=1))
        nearest = int(np.argmin(dists))  # first minimum, like list.index(min)
        min_value = dists[nearest]
        if min_value > threshold:
            labels.append('unknow')
        else:
            print("--------{}--------".format(min_value))
            labels.append(all_obj[nearest])
    return labels


def _draw_labels(frame, bounding_box, labels):
    """Draw a rectangle and a text label on ``frame`` for every detected face."""
    for box, label in zip(bounding_box, labels):
        # Plain Python ints keep OpenCV's point parsing happy regardless of
        # the NumPy integer dtype of the box array.
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2, 8, 0)
        cv2.putText(
            frame,
            label,
            (x1, y1),
            cv2.FONT_HERSHEY_COMPLEX_SMALL,
            0.8,
            (0, 0, 255),
            thickness=2,
            lineType=2)


def main_bak(model=_DEFAULT_MODEL_PATH, emb_dir=_DEFAULT_EMB_DIR, video=0, threshold=0.5):
    """Run live face recognition on a video source until ESC is pressed.

    Loads the FaceNet model, embeds every image found in ``emb_dir`` (the
    file names become the labels), then reads frames from ``video``, detects
    faces with load_and_align_data(), labels each face with its nearest
    gallery entry and shows the annotated frame in a window named "camera".

    Args:
        model: Path handed to ``facenet.load_model``.
        emb_dir: Directory of reference face images. They are stacked into
            one batch, so they must all share the same shape.
        video: Source passed to ``cv2.VideoCapture``; ``0`` opens the
            built-in camera, a path or URL opens that file or stream.
        threshold: Largest embedding distance still accepted as a match;
            faces further away are labelled ``'unknow'``.

    Raises:
        ValueError: If ``emb_dir`` contains no entries.
    """
    with tf.Graph().as_default():
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6, allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False)
        with tf.Session(config=config) as sess:
            # Load the model and look up its input and output tensors.
            facenet.load_model(model)
            graph = tf.get_default_graph()
            images_placeholder = graph.get_tensor_by_name("input:0")
            embeddings = graph.get_tensor_by_name("embeddings:0")
            phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

            # Build the gallery: one prewhitened image per file in emb_dir.
            # NOTE: scipy.misc.imread only exists in older SciPy releases.
            all_obj = []
            image = []
            for file_name in os.listdir(emb_dir):
                all_obj.append(file_name)
                img = misc.imread(os.path.join(emb_dir, file_name), mode='RGB')
                image.append(facenet.prewhiten(img))

            # Fail with a clear message instead of np.stack's generic one.
            if not image:
                raise ValueError(
                    "no reference images found in {!r}".format(emb_dir))

            images = np.stack(image)
            feed_dict = {images_placeholder: images, phase_train_placeholder: False}
            compare_emb = sess.run(embeddings, feed_dict=feed_dict)

            capture = cv2.VideoCapture(video)
            cv2.namedWindow("camera", 1)
            try:
                while True:
                    ret, frame = capture.read()
                    # A failed read (camera unavailable, end of file) yields
                    # no frame; stop cleanly rather than crash in cvtColor.
                    if not ret or frame is None:
                        print("failed to read a frame from the video source, stopping...")
                        break

                    # The detector expects RGB, OpenCV delivers BGR.
                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                    # Face detection: success flag, boxes, prewhitened crops.
                    mark, bounding_box, crop_image = load_and_align_data(rgb_frame, 160, 44)

                    if mark:
                        feed_dict = {images_placeholder: crop_image, phase_train_placeholder: False}
                        emb = sess.run(embeddings, feed_dict=feed_dict)

                        # Match a label to every bounding box, then draw.
                        fin_obj = _match_labels(emb, compare_emb, all_obj, threshold)
                        _draw_labels(frame, bounding_box, fin_obj)

                        # NOTE(review): as in the original, the window is only
                        # refreshed when at least one face was detected.
                        cv2.imshow('camera', frame)

                    key = cv2.waitKey(10)
                    if key == 27:
                        # ESC quits.
                        print("esc break...")
                        break
            finally:
                # Release the camera and window even if the loop raised.
                capture.release()
                cv2.destroyWindow("camera")

        # When everything is done, release the capture


# Original note: "create the networks for load_and_align_data".
# NOTE(review): this message is printed at import time, after the module-level
# MTCNN setup near the top of the file has already run; the FaceNet model
# itself is only loaded later, inside main_bak().
print('Creating networks and loading parameters')

# Start the live recognition loop only when executed as a script.
if __name__ == '__main__':
    main_bak()
